# # clear existing user defined variables
# for element in dir():
#     if element[0:2] != "__":
#         del globals()[element]

import os
import pickle
import argparse
from time import time

import numpy as np
import tensorflow as tf

from functions_preprocessing import english_standardization, create_train_eval_datasets
from models import build_model, CustomCheckpoint
#from matplotlib import pyplot as plt


# =================================================================
# =================================================================
# KEYWORD SPOTTING
# TRAINING CONVOLUTIONAL RECURRENT NEURAL NETWORK
# =================================================================
# =================================================================

# python kws_train.py -r result_model02 -m 2
# python kws_train.py -r result_model03 -m 3
# python kws_train.py -r result_model04 -m 4

# python kws_train.py -r result_model02_ds -m 2 -d data2
# python kws_train.py -r result_model02_noise -m 2 --noise_aug 0.5

# ----------------------------------------------------------
# Training Configurations and Feature Extraction Parameters
# ----------------------------------------------------------

# Command-line interface. All values are validated here, before any result
# folder is created further down, so a bad invocation leaves nothing behind.
parser = argparse.ArgumentParser(
    description="Train a convolutional recurrent keyword-spotting model.")
parser.add_argument("-d","--datafolder",default='data2',type=str,help="Path to data folder - generated by `prepare_dataset.py`.")
parser.add_argument("-r","--resultfolder",default='result_02',type=str,help="Path to result folder - to be created.")

parser.add_argument("-f","--feature_type",default='mfcc',type=str,choices=['mfcc','melspec'],help="Feature type: `mfcc` or `melspec`.")
parser.add_argument("-m","--model_num",default=2,type=int,help="Choose model to train. Currently implemented from 1 to 6.")

parser.add_argument("--train_test_split",default=0.8,type=float,help="Proportion of the dataset used for training.")
parser.add_argument("--batch_size",default=64,type=int,help="Batch size for training.")
parser.add_argument("--noise_aug",default=0.0,type=float,help="Noise augmentation rate (from 0 to 1).")
parser.add_argument("--rng_seed",default=47,type=int,help="Seed for random dataset shuffle.")
parser.add_argument("--epochs",default=10,type=int,help="Num. of epochs for training.")

parser.add_argument("--frame_length",default=256,type=int,help="Num. of audio samples composing each feature frame.")

parser.add_argument("--lower_freq",default=80.0,type=float,help="Lower frequency for the mel-filter banks.")
parser.add_argument("--upper_freq",default=7600.0,type=float,help="Upper frequency for the mel-filter banks.")
parser.add_argument("--n_mel_bins",default=80,type=int,help="Num. of mel-filter banks.")
parser.add_argument("--frame_step",default=128,type=int,help="Num. of samples to shift between frames.")
parser.add_argument("--fft_length",default=256,type=int,help="FFT resolution.")

parser.add_argument("--n_mfcc_bins",default=13,type=int,help="Num. of mel-frequency cepstral coefficients.")
args = parser.parse_args()

# Range checks that argparse's `type=` cannot express. parser.error prints the
# usage line plus the message and exits with status 2.
if not 0.0 < args.train_test_split <= 1.0:
    parser.error("--train_test_split must be in the interval (0, 1].")
if not 0.0 <= args.noise_aug <= 1.0:
    parser.error("--noise_aug must be in the interval [0, 1].")
if args.batch_size < 1:
    parser.error("--batch_size must be a positive integer.")


# ----------------------------------------------------------
#  Creating Folder and Exporting Parameters
# ----------------------------------------------------------

# Human-readable summary of this run, written to <resultfolder>/description.txt.
# The feature note depends on the selected feature type; previously it always
# claimed MFCC features, even for `melspec` runs.
feature_notes = {'mfcc': '(only MFCC features, and no deltas)',
                 'melspec': '(mel-spectrogram features)'}
feature_note = feature_notes.get(args.feature_type, '')

descript = f"""TRAINING DESCRIPTION

DATAFOLDER = {args.datafolder}
TRAIN_TEST_SPLIT = {args.train_test_split}
BATCH_SIZE = {args.batch_size}
EPOCHS = {args.epochs}
RNG_SEED = {args.rng_seed}
NOISE_AUG = {args.noise_aug}

FEATURE_TYPE = {args.feature_type}
FRAME_LENGTH = {args.frame_length}
FRAME_STEP = {args.frame_step}
FFT_LENGTH = {args.fft_length}
LOWER_FREQ = {args.lower_freq}
UPPER_FREQ = {args.upper_freq}
N_MEL_BINS = {args.n_mel_bins}
N_MFCC_BINS = {args.n_mfcc_bins}
{feature_note}

MODEL = {args.model_num}
"""

# Machine-readable training/feature settings, pickled to
# <resultfolder>/parameters.pickle. The key set is kept stable on purpose,
# because other scripts may read this file back.
train_params = {'feature_type': args.feature_type,
                'model_num': args.model_num,
                'train_test_split': args.train_test_split,
                'batch_size': args.batch_size,
                'rng_seed': args.rng_seed,
                'noise_aug': args.noise_aug,
                'frame_length': args.frame_length,
                'frame_step': args.frame_step,
                'fft_length': args.fft_length,
                'lower_freq': args.lower_freq,
                'upper_freq': args.upper_freq,
                'n_mel_bins': args.n_mel_bins,
                'n_mfcc_bins': args.n_mfcc_bins}

# Validate the feature type BEFORE touching the filesystem. Otherwise a bad
# value would leave behind a half-initialized result folder, and every re-run
# would then be refused with the "already exists" error below.
feat_dims = {'mfcc': args.n_mfcc_bins, 'melspec': args.n_mel_bins}
if args.feature_type not in feat_dims:
    raise ValueError('Feature has not been correctly selected.')
# Per-frame feature dimension passed to build_model.
feat_dim = feat_dims[args.feature_type]

if not os.path.isdir(args.datafolder):
    raise FileNotFoundError('Specified datafolder does not exist.')
try:
    # makedirs already refuses an existing path. Relying on it, instead of a
    # separate exists() check, removes the window between check and create.
    os.makedirs(args.resultfolder)
except FileExistsError:
    raise FileExistsError('Result folder already exists. Make sure not to overwrite previous data.') from None

# Save parameters, the description text and the header of the loss-history file.
with open(os.path.join(args.resultfolder,"parameters.pickle"), 'wb') as f:
    pickle.dump(train_params, f)
with open(os.path.join(args.resultfolder,'description.txt'), 'w', encoding='utf-8') as f:
    f.write(descript)
with open(os.path.join(args.resultfolder,'hist.txt'), 'w', encoding='utf-8') as f:
    f.write('epoch\tloss\t\tval_loss\n')

# ----------------------------------------------------------------

# Read the settings stored alongside the dataset (keyword list, keyword count,
# etc.) plus the noise-augmentation configuration, then build the text vectorizer.
# NOTE(review): pickle.load can execute arbitrary code, so only point
# --datafolder at folders you generated yourself.
params_path = os.path.join(args.datafolder, "parameters.pickle")
noise_path = os.path.join(args.datafolder, "noise_aug.pickle")
with open(params_path, 'rb') as fh:
    dataset_params = pickle.load(fh)
with open(noise_path, 'rb') as fh:
    noise_settings = pickle.load(fh)

keywords, num_kwd = dataset_params['keywords'], dataset_params['num_kwd']

# Tokenizes target transcripts against a fixed vocabulary made of the keywords.
text_processor = tf.keras.layers.experimental.preprocessing.TextVectorization(
    vocabulary=keywords,
    max_tokens=None,
    standardize=english_standardization,
)


# =====================================================
print('(1) Load Dataset Into Lists')
# =====================================================

# (A) load the per-corpus dictionaries (input paths and target texts)

corpus_files = ("speech_commands_dict.pickle",
                "speech_commands_edit_dict.pickle",
                "librispeech_dict.pickle")
corpora = []
for corpus_file in corpus_files:
    with open(os.path.join(args.datafolder, corpus_file), 'rb') as fh:
        corpora.append(pickle.load(fh))
data1, data2, data3 = corpora


# (B) concatenate the corpora in a fixed order, then apply a seeded shuffle

input_paths = [path for corpus in corpora for path in corpus['input_path']]
target_texts = [text for corpus in corpora for text in corpus['target_text']]

num_samples = len(input_paths)
# Seeds NumPy's *global* RNG, which also affects any later np.random use.
# Both lists are reordered with the same permutation so pairs stay aligned.
np.random.seed(args.rng_seed)
shuffle_idx = np.random.permutation(num_samples)
input_paths = [input_paths[k] for k in shuffle_idx]
target_texts = [target_texts[k] for k in shuffle_idx]

# Reporting only: the actual split is done inside create_train_eval_datasets
# from the same train_test_split. Presumably it uses the same rounding; TODO confirm.
num_train = int( num_samples * args.train_test_split )
print(f'Number of transcription pairs: {num_samples}')
print(f'Number of training samples: {num_train}')
print(f'Number of validation samples: {num_samples-num_train}')


# =====================================================
print('(2) Create Dataset Object')
# =====================================================

train_ds, valid_ds = create_train_eval_datasets(input_paths, target_texts,
    noise_settings=noise_settings, noise_prob=args.noise_aug,
    feature_type=args.feature_type, vectorizer=text_processor, sr=dataset_params['sampling_rate'],
    frame_len=args.frame_length, frame_hop=args.frame_step, fft_len=args.fft_length,
    num_mel_bins=args.n_mel_bins, lower_freq=args.lower_freq, upper_freq=args.upper_freq,
    num_mfcc=args.n_mfcc_bins, batch_size=args.batch_size, train_test_split=args.train_test_split)

# Print (feature shape, target) for a few items of the first batch of each split.
# A batch can hold fewer than NUM_EXAMPLES items (small --batch_size, or a tiny
# validation split), so cap the count at the batch length instead of assuming it.
NUM_EXAMPLES = 5
print('Dataset examples')
t1 = time()
for split_ds in (train_ds, valid_ds):
    for pair in split_ds.take(1):
        n_show = min(NUM_EXAMPLES, int(pair['input'].shape[0]))
        for i in range(n_show):
            print(pair['input'][i].shape, pair['target'][i])
# t1 brackets the first-batch fetch above, which includes the pipeline warm-up.
print(f'Time to fetch example batches: {time() - t1:.2f} s')


# =====================================================
print('\n(3) Load Model and Train')
# =====================================================

# build_model returns a pair; only the first element (the trainable model) is used here.
model_train, _ = build_model(args.model_num, feat_dim, num_kwd)
model_train.summary()

# Callback that saves model parameters and records performance into the result folder.
checkpointer = CustomCheckpoint(args.resultfolder)

fit_kwargs = dict(
    validation_data=valid_ds,
    epochs=args.epochs,
    callbacks=[checkpointer],
)
model_train.fit(train_ds, **fit_kwargs)





# ==================================================================